#include <curand_kernel.h>

/**
 * Applies dropout in place to one element per thread and records the random
 * draw that decided whether the element was kept.
 *
 * The thread with global index `idx = blockIdx.x * blockDim.x + threadIdx.x`
 * handles element `idx`; threads with `idx >= len` do nothing.
 *
 * @param state      Generator states, one per element; `state[idx]` is
 *                   advanced by one draw.
 * @param len        Number of elements to process.
 * @param input      Values updated in place to `input[idx] * keep * scale`,
 *                   where `keep` is 1 if the draw is strictly greater than
 *                   `ratio` and 0 otherwise.
 * @param rand_vals  Output: `rand_vals[idx]` receives the uniform draw
 *                   (a `float` from `curand_uniform`, converted to `T`).
 *                   dropout_backward applies the same `> ratio` test to a
 *                   `rand_vals` buffer — presumably this one; confirm with
 *                   the caller.
 * @param ratio      Threshold the draw is compared against.
 * @param scale      Factor applied to every element, kept or dropped.
 */
template <typename T>
__device__ void dropout_forward(curandState *state, int len,
    T *input, T *rand_vals, T ratio, T scale) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) return;

  // Advance the generator on a local copy and write it back once, instead of
  // letting curand_uniform update the state through the global pointer.
  curandState local_state = state[idx];
  const T draw = curand_uniform(&local_state);
  state[idx] = local_state;

  rand_vals[idx] = draw;

  // Multiplying by a 0/1 mask (rather than selecting 0) keeps the original
  // arithmetic exactly, including how NaN/Inf inputs propagate.
  const T keep = (draw > ratio) ? T(1) : T(0);
  input[idx] = input[idx] * keep * scale;
}
/**
 * Applies a dropout mask to one gradient element per thread, in place.
 *
 * The mask is rebuilt from `rand_vals` with the test `rand_vals[i] > ratio`;
 * no random numbers are generated here.
 *
 * @param state      Not used by this function; it is part of the signature
 *                   shared with dropout_forward.
 * @param len        Number of elements to process; threads whose global
 *                   index is `>= len` do nothing.
 * @param grad       Values updated in place to `grad[i] * keep * scale`,
 *                   where `keep` is 1 if `rand_vals[i] > ratio` and 0
 *                   otherwise.
 * @param rand_vals  Per-element values compared against `ratio` (read only).
 * @param ratio      Threshold for the comparison.
 * @param scale      Factor applied to every element.
 */
template <typename T>
__device__ void dropout_backward(curandState *state, int len,
    T *grad, T *rand_vals, T ratio, T scale) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= len) return;

  // 1 where the stored value exceeds the threshold, 0 elsewhere.
  const T keep = static_cast<T>(rand_vals[tid] > ratio);
  grad[tid] = grad[tid] * keep * scale;
}

// Defines a `__global__` kernel named <FUNC>_<TYPE> that forwards all of its
// arguments, unchanged and in order, to the device function FUNC.
//   FUNC: name of the device function to wrap (also the kernel-name prefix)
//   TYPE: element type used for the data, rand_vals, ratio and scale arguments
#define DEF_DROPOUT_API(FUNC, TYPE)                                      \
  __global__ void FUNC##_##TYPE(curandState *state, int len, TYPE *data, \
                                TYPE *rand_vals, TYPE ratio,             \
                                TYPE scale) {                            \
    FUNC(state, len, data, rand_vals, ratio, scale);                     \
  }

extern "C" {
/**
 * Initializes one cuRAND generator state per thread.
 *
 * Thread `tid = blockIdx.x * blockDim.x + threadIdx.x` initializes
 * `state[tid]`; threads with `tid >= len` do nothing.
 *
 * @param state  Array of generator states to initialize.
 * @param len    Number of states to initialize.
 */
__global__ void dropout_init(curandState *state, int len) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= len) return;

  // Fixed seed 3085, sequence number = thread index, offset 0.
  curand_init(3085, tid, 0, state + tid);
}

/**
 * Stores `sizeof(curandState)`, converted to `double`, in `*psize`.
 *
 * Every thread that runs this kernel writes the same value to the same
 * location. Presumably this lets host code size the state buffer without
 * knowing the struct layout — confirm with the caller.
 *
 * @param psize  Destination for the size in bytes.
 */
__global__ void dropout_alloc_size(double *psize) {
  const size_t state_bytes = sizeof(curandState);
  psize[0] = static_cast<double>(state_bytes);
}

// Instantiate the exported kernels. Through the macro's token pasting these
// define dropout_forward_float, dropout_forward_double,
// dropout_backward_float and dropout_backward_double, each forwarding to the
// template device function of the same base name.
DEF_DROPOUT_API(dropout_forward, float)
DEF_DROPOUT_API(dropout_forward, double)
DEF_DROPOUT_API(dropout_backward, float)
DEF_DROPOUT_API(dropout_backward, double)

} // extern "C"

// vim: ft=cuda
